#include <stdio.h>

#include <iostream>
#include <vector>

#include "src/libeasynn.h"

/// Builds a graph computing `a + a` for a single input "a" and binds a 2x3
/// ndarray to "a".
///
/// Sub and Mul can be tested the same way by replacing "Add" with "Sub" or
/// "Mul".
///
/// @param prog program to append the expressions to.
/// @return evaluation built from @p prog with kwarg "a" bound.
evaluation* add_test(program *prog){
    // Input takes no upstream expressions. Pass nullptr with a count of 0
    // rather than a zero-length array, which ISO C++ forbids.
    append_expression(prog, 0, "a", "Input", nullptr, 0);

    // Expression 1 adds expression 0 to itself.
    int add_inputs[] = {0, 0};
    append_expression(prog, 1, "", "Add", add_inputs, 2);

    evaluation *eval = build(prog);

    // Scalar alternative: add_kwargs_double(eval, "a", 5);
    int dim = 2;
    size_t shape[] = {2, 3};
    double data[] = {1, 2, 3, 4, 5, 6};
    // NOTE(review): assumes add_kwargs_ndarray copies shape/data -- these are
    // stack arrays that go out of scope when this function returns.
    add_kwargs_ndarray(eval, "a", dim, shape, data);
    return eval;
}

/// Builds a graph computing `a + b` for two scalar inputs and binds
/// a = 5, b = 6.
///
/// Sub and Mul can be tested by replacing "Add" with "Sub" or "Mul".
///
/// @param prog program to append the expressions to.
/// @return evaluation built from @p prog with kwargs "a" and "b" bound.
evaluation* add_test1(program *prog){
    // Two leaf inputs (ids 0 and 1). Input takes no upstream expressions, so
    // pass nullptr/0 instead of a zero-length array (ill-formed in ISO C++).
    append_expression(prog, 0, "a", "Input", nullptr, 0);
    append_expression(prog, 1, "b", "Input", nullptr, 0);

    // Expression 2 takes expressions 0 and 1 as its inputs.
    int add_inputs[] = {0, 1};
    append_expression(prog, 2, "", "Add", add_inputs, 2);

    evaluation* eval = build(prog);
    add_kwargs_double(eval, "a", 5);
    add_kwargs_double(eval, "b", 6);
    return eval;
}

/// Builds a graph applying Relu to input "a" and binds a 2x3 ndarray that
/// mixes positive, negative and zero values.
///
/// @param prog program to append the expressions to.
/// @return evaluation built from @p prog with kwarg "a" bound.
evaluation* relu_test(program *prog){
    // Input takes no upstream expressions; nullptr/0 replaces the
    // zero-length array (ill-formed in ISO C++).
    append_expression(prog, 0, "a", "Input", nullptr, 0);

    // Relu over expression 0. The id is 1 so that ids stay sequential, as in
    // the other tests (the original used 2, leaving a gap).
    int relu_inputs[] = {0};
    append_expression(prog, 1, "", "Relu", relu_inputs, 1);

    evaluation* eval = build(prog);

    // Scalar alternative: add_kwargs_double(eval, "a", -5);
    int dim = 2;
    size_t shape[] = {2, 3};
    double data[] = {11, -2, 0, 4, -5, 6};
    // NOTE(review): assumes add_kwargs_ndarray copies shape/data (stack
    // arrays that die when this function returns).
    add_kwargs_ndarray(eval, "a", dim, shape, data);
    return eval;
}


/// Builds a graph applying Flatten to input "a" and binds a 4-D ndarray of
/// shape {2,3,2,2} (24 values).
///
/// @param prog program to append the expressions to.
/// @return evaluation built from @p prog with kwarg "a" bound.
evaluation* flatten_test(program *prog){
    // Input takes no upstream expressions; nullptr/0 replaces the
    // zero-length array (ill-formed in ISO C++).
    append_expression(prog, 0, "a", "Input", nullptr, 0);

    int flatten_inputs[] = {0};
    append_expression(prog, 1, "", "Flatten", flatten_inputs, 1);

    evaluation* eval = build(prog);

    int dim = 4;
    size_t shape[] = {2, 3, 2, 2};
    // 2 * 3 * 2 * 2 = 24 values; the line breaks are only for readability.
    double data[] = {1, 11, 111,
                     2, 22, 222,
                     3, 33, 333,
                     4, 44, 444,

                     1, 11, 111,
                     2, 22, 222,
                     3, 33, 333,
                     4, 44, 444};
    // NOTE(review): assumes add_kwargs_ndarray copies shape/data (stack
    // arrays that die when this function returns).
    add_kwargs_ndarray(eval, "a", dim, shape, data);
    return eval;
}


/// Builds a graph consisting of a single Input2d expression "a" and binds a
/// 4-D ndarray of shape {2,3,2,2} (24 values) to it.
///
/// The test label in main describes this as an NHWC => NCHW conversion.
///
/// @param prog program to append the expression to.
/// @return evaluation built from @p prog with kwarg "a" bound.
evaluation* Input2d_test(program *prog){
    // Input2d is the graph's only (leaf) expression. The original listed
    // expression 0 -- i.e. itself -- as its single input, which is a
    // self-reference. It is now declared with no inputs, like "Input".
    // NOTE(review): confirm the Input2d implementation does not read inputs.
    append_expression(prog, 0, "a", "Input2d", nullptr, 0);

    evaluation* eval = build(prog);

    int dim = 4;
    size_t shape[] = {2, 3, 2, 2};
    // 2 * 3 * 2 * 2 = 24 values; the line breaks are only for readability.
    double data[] = {1, 11, 111,
                     2, 22, 222,
                     3, 33, 333,
                     4, 44, 444,

                     1, 11, 111,
                     2, 22, 222,
                     3, 33, 333,
                     4, 44, 444};
    // NOTE(review): assumes add_kwargs_ndarray copies shape/data (stack
    // arrays that die when this function returns).
    add_kwargs_ndarray(eval, "a", dim, shape, data);
    return eval;
}



/// Builds a graph applying Linear to input "x".
///
/// The layer's weight {2,4} and bias {2} are attached as op params, and a
/// {3,4} ndarray is bound to "x".
///
/// NOTE(review): the shapes suggest y = x * W^T + b with a {3,2} result
/// (main labels it "ax+b"); confirm against the Linear implementation.
///
/// @param prog program to append the expressions and op params to.
/// @return evaluation built from @p prog with kwarg "x" bound.
evaluation* linear_test(program *prog){
    // Only x is a graph input; weight and bias are op params (below).
    // nullptr/0 replaces the zero-length array (ill-formed in ISO C++).
    append_expression(prog, 0, "x", "Input", nullptr, 0);

    // Expression 1 applies Linear to expression 0.
    int linear_inputs[] = {0};
    append_expression(prog, 1, "", "Linear", linear_inputs, 1);

    // The params are added right after the Linear expression.
    // Presumably they attach to it -- TODO confirm with the library.
    int dim_w = 2;
    size_t shape_w[] = {2, 4};
    double data_w[] = {1, 2, 1, 2,
                       1, 2, 1, 2};
    add_op_param_ndarray(prog, "weight", dim_w, shape_w, data_w);

    int dim_b = 1;
    size_t shape_b[] = {2};
    double data_b[] = {1, 2};
    add_op_param_ndarray(prog, "bias", dim_b, shape_b, data_b);

    evaluation* eval = build(prog);

    int dim_x = 2;
    size_t shape_x[] = {3, 4};
    double data_x[] = {1, 2, 1, 2,
                       1, 2, 1, 2,
                       1, 2, 1, 2};
    add_kwargs_ndarray(eval, "x", dim_x, shape_x, data_x);
    return eval;
}


/// Builds a graph applying MaxPool2d (kernel_size = 2, stride = 2) to input
/// "x" and binds a 4-D ndarray of shape {2,3,2,2} (24 values).
///
/// NOTE(review): the data layout suggests NCHW (each row of 4 values is one
/// 2x2 channel); confirm against the MaxPool2d implementation.
///
/// @param prog program to append the expressions and op params to.
/// @return evaluation built from @p prog with kwarg "x" bound.
evaluation* maxpool2d_test(program *prog){
    // x is the only graph input; nullptr/0 replaces the zero-length array
    // (ill-formed in ISO C++).
    append_expression(prog, 0, "x", "Input", nullptr, 0);

    // Expression 1 applies MaxPool2d to expression 0.
    int pool_inputs[] = {0};
    append_expression(prog, 1, "", "MaxPool2d", pool_inputs, 1);

    double kernel_size = 2;
    double stride = 2;
    add_op_param_double(prog, "kernel_size", kernel_size);
    add_op_param_double(prog, "stride", stride);

    evaluation* eval = build(prog);

    int dim_x = 4;
    size_t shape_x[] = {2, 3, 2, 2};
    double data_x[] = {1, 11, 111, 1111,
                       2, 22, 222, 2222,
                       3, 33, 333, 3333,

                       4, 44, 444, 4444,
                       5, 55, 555, 5555,
                       6, 66, 666, 6666};
    add_kwargs_ndarray(eval, "x", dim_x, shape_x, data_x);
    return eval;
}


/// Builds a graph applying Conv2d (1 input channel, 1 output channel,
/// 2x2 kernel) to input "x".
///
/// Binds a {2,1,3,3} ndarray (18 values) to "x".
///
/// @param prog program to append the expressions and op params to.
/// @return evaluation built from @p prog with kwarg "x" bound.
evaluation* conv2d_test(program *prog){
    // x is the only graph input; nullptr/0 replaces the zero-length array
    // (ill-formed in ISO C++).
    append_expression(prog, 0, "x", "Input", nullptr, 0);

    // Expression 1 applies Conv2d to expression 0.
    int conv_inputs[] = {0};
    append_expression(prog, 1, "", "Conv2d", conv_inputs, 1);

    // kernel_size must agree with the 2x2 spatial extent of "weight" below
    // (the original said 3). Each param is now passed its own value; the
    // original passed kernel_size for in_channels and out_channels too.
    double kernel_size = 2;
    double in_channels = 1;
    double out_channels = 1;
    add_op_param_double(prog, "kernel_size", kernel_size);
    add_op_param_double(prog, "in_channels", in_channels);
    add_op_param_double(prog, "out_channels", out_channels);

    // NOTE(review): assumes weight layout (out_channels, in_channels, kH, kW),
    // as in PyTorch's Conv2d -- confirm with the library.
    int dim_w = 4;
    size_t shape_w[] = {1, 1, 2, 2};
    double data_w[] = {1, 2,
                       1, 2};
    add_op_param_ndarray(prog, "weight", dim_w, shape_w, data_w);

    int dim_b = 1;
    size_t shape_b[] = {1};
    double data_b[] = {1};
    add_op_param_ndarray(prog, "bias", dim_b, shape_b, data_b);

    evaluation* eval = build(prog);

    int dim_x = 4;
    size_t shape_x[] = {2, 1, 3, 3};
    double data_x[] = {1, 11, 111,
                       2, 22, 222,
                       3, 33, 333,

                       4, 44, 444,
                       5, 55, 555,
                       6, 66, 666};
    add_kwargs_ndarray(eval, "x", dim_x, shape_x, data_x);
    return eval;
}






/// Test driver: builds one of the graphs above, executes it with
/// execute_cpp, and prints every element of the result on its own line.
///
/// To select a test, leave exactly one `evaluation* eval = ...` line
/// uncommented.
int main()
{
    program *prog = create_program();

    // Add / Sub / Mul of an input with itself
    //evaluation* eval = add_test(prog);
    // Add / Sub / Mul of two different inputs
    //evaluation* eval = add_test1(prog);
    // Relu
    //evaluation* eval = relu_test(prog);

    // Flatten
    //evaluation* eval = flatten_test(prog);

    // Input2d: NHWC => NCHW
    //evaluation* eval = Input2d_test(prog);

    // Linear: ax + b
    //evaluation* eval = linear_test(prog);

    // MaxPool2d
    evaluation* eval = maxpool2d_test(prog);

    // Conv2d (work in progress). The original line here wrongly called
    // maxpool2d_test.
    //evaluation* eval = conv2d_test(prog);

    int dim = 0;
    size_t *shape = nullptr;
    double *data = nullptr;
    const std::vector<double> result = execute_cpp(eval, &dim, &shape, &data);

    // Range-for avoids the signed/unsigned comparison of an int index
    // against size(). '\n' avoids flushing the stream on every element;
    // std::cout is flushed at program exit.
    for (const double value : result) {
        std::cout << value << '\n';
    }

    // Alternative using the C API:
    // if (execute(eval, &dim, &shape, &data) != 0)
    // {
    // printf("evaluation fails\n");
    // return -1;
    // }
    // std::cout << "dim:" << dim << std::endl;
    // if (dim == 0)
    // printf("res = %f\n", data[0]);
    // else
    // printf("result as tensor is not supported yet\n");
    return 0;
}